package com.jl.mapper;

import com.jl.JLLambda;
import com.jl.JLEmpty;
import com.jl.JLReflect;
import com.jl.JLTuple;
import com.jl.springbean.JLInvocationHandler;
import com.jl.springbean.annotation.JLProxy;
import com.mongodb.client.result.DeleteResult;
import com.mongodb.client.result.UpdateResult;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.*;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.repository.support.PageableExecutionUtils;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Base mapper implementation backed by a Spring {@link MongoTemplate}.
 *
 * <p>The template is looked up via {@code getBean(MongoTemplate.class)} on every call. The entity
 * class used for typed queries is the {@code classa} field, which is not declared here
 * (presumably inherited from {@link JLInvocationHandler} — TODO confirm).
 *
 * @param <E> entity type handled by this mapper
 */
@JLProxy
public class JLMongoMapperImpl<E> extends JLInvocationHandler<E> implements JLMongoMapper<E> {

    /**
     * @param interfaceType mapper interface type, forwarded unchanged to the superclass constructor
     */
    public JLMongoMapperImpl(Class interfaceType) {
        super(interfaceType);
    }

    /**
     * Resolves the {@link MongoTemplate} bean; called anew for every operation.
     */
    private MongoTemplate mongoTemplate() {
        return getBean(MongoTemplate.class);
    }

    /**
     * Builds a {@code $set} update from the entity's properties as reported by
     * {@code JLReflect.PropertyReflect.getProperty}.
     *
     * <p>Properties whose value is {@code null} or the empty string are skipped, so they are left
     * untouched in the stored document.
     *
     * @param entity source of the values to set
     * @return an {@link Update} containing one {@code set} per non-null, non-empty property
     */
    private Update buildUpdate(E entity) {
        Update update = new Update();
        List<JLTuple.Tuple3<String, Object, Class<?>>> properties = JLReflect.PropertyReflect.getProperty(entity);
        for (JLTuple.Tuple3<String, Object, Class<?>> property : properties) {
            Object value = property.getV2();
            // "".equals(value) is only true for an empty String, so no separate instanceof check is needed
            if (value == null || "".equals(value)) {
                continue;
            }
            update.set(property.getV1(), value);
        }
        return update;
    }

    /**
     * Inserts the entity.
     *
     * @param entity entity to insert
     * @return the entity as returned by {@link MongoTemplate#insert(Object)}
     */
    @Override
    public E save(E entity) {
        return mongoTemplate().insert(entity);
    }

    /**
     * Inserts or updates (saves) the entity.
     *
     * @param entity entity to save
     * @return the entity as returned by {@link MongoTemplate#save(Object)}
     */
    @Override
    public E saveOrUpdate(E entity) {
        return mongoTemplate().save(entity);
    }

    /**
     * Sets the entity's non-null, non-empty properties on the document matched by {@code query}.
     *
     * <p>NOTE(review): this performs an <em>upsert</em>. When nothing matches, a new document is
     * inserted. By MongoDB semantics an inserted (upserted) document is not counted as "modified",
     * so this returns 0 in that case.
     *
     * @param query  selects the document to update
     * @param entity source of the values to set; null and "" properties are ignored
     * @return the modified count reported by the driver
     */
    @Override
    public long update(Query query, E entity) {
        UpdateResult result = mongoTemplate().upsert(query, buildUpdate(entity), classa);
        return result.getModifiedCount();
    }

    /**
     * Removes the given entity.
     *
     * @param entity entity to remove
     * @return number of deleted documents
     */
    @Override
    public long remove(E entity) {
        DeleteResult result = mongoTemplate().remove(entity);
        return result.getDeletedCount();
    }

    /**
     * Removes all documents matching the query.
     *
     * @param query selection criteria
     * @return number of deleted documents
     */
    @Override
    public long remove(Query query) {
        DeleteResult result = mongoTemplate().remove(query, classa);
        return result.getDeletedCount();
    }

    /**
     * Counts documents matching the query.
     *
     * @param query selection criteria
     * @return matching document count
     */
    @Override
    public long count(Query query) {
        return mongoTemplate().count(query, classa);
    }

    /**
     * Finds a document matching {@code query} and applies the entity's non-null, non-empty properties to it.
     *
     * @param query  selects the document to modify
     * @param entity source of the values to set; null and "" properties are ignored
     * @return the document as it was <em>before</em> the update (default findAndModify options), or
     *         {@code null} if nothing matched
     */
    @Override
    public E getAndUpdate(Query query, E entity) {
        return mongoTemplate().findAndModify(query, buildUpdate(entity), classa);
    }

    /**
     * Finds a document matching {@code query} and removes it.
     *
     * @param query selection criteria
     * @return the removed document, or {@code null} if nothing matched
     */
    @Override
    public E getAndRemove(Query query) {
        return mongoTemplate().findAndRemove(query, classa);
    }

    /**
     * Looks up a document by its id.
     *
     * @param id document id
     * @return the document, or {@code null} if not found
     */
    @Override
    public E getById(Object id) {
        return mongoTemplate().findById(id, classa);
    }

    /**
     * Returns a single document matching the query.
     *
     * @param query selection criteria
     * @return the document, or {@code null} if nothing matched
     */
    @Override
    public E getOne(Query query) {
        return mongoTemplate().findOne(query, classa);
    }

    /**
     * Returns all documents matching the query.
     *
     * @param query selection criteria
     * @return matching documents
     */
    @Override
    public List<E> list(Query query) {
        return mongoTemplate().find(query, classa);
    }

    /**
     * Returns one page of documents matching the query.
     *
     * <p>NOTE: {@code query} is modified in place. {@code query.with(pageable)} applies the page's
     * skip, limit and sort to the caller's object.
     *
     * @param query    selection criteria
     * @param pageable page to fetch
     * @return the requested page with the total match count
     */
    @Override
    public Page<E> page(Query query, Pageable pageable) {
        // Count before applying the pageable, so that skip/limit do not affect the total.
        long total = count(query);
        query.with(pageable);
        List<E> content = mongoTemplate().find(query, classa);
        return PageableExecutionUtils.getPage(content, pageable, () -> total);
    }

    /**
     * Starts an aggregation pipeline whose first stage is {@code $match(criteria)}.
     *
     * @param criteria match criteria for the first stage
     * @return a builder for further pipeline stages
     */
    @Override
    public AggregateWhere<E> aggregate(Criteria criteria) {
        return new AggregateWhere<>(criteria, mongoTemplate(), classa);
    }

    /**
     * Fluent builder for an aggregation pipeline over entity type {@code E}.
     *
     * @param <E> entity type
     */
    public static class AggregateWhere<E> {
        private final MongoTemplate mongoTemplate;
        private final Class<E> t;
        private final List<AggregationOperation> aggregationOperations = new ArrayList<>();

        public AggregateWhere(Criteria criteria, MongoTemplate mongoTemplate, Class<E> t) {
            this.aggregationOperations.add(Aggregation.match(criteria));
            this.mongoTemplate = mongoTemplate;
            this.t = t;
        }

        /**
         * Appends an ascending {@code $sort} on the property referenced by {@code jlFunction}.
         */
        public AggregateWhere<E> asc(JLLambda.JLFunction<E, ?> jlFunction) {
            SortOperation asc = Aggregation.sort(Sort.by(Sort.Order.asc(JLLambda.getProperty(jlFunction))));
            aggregationOperations.add(asc);
            return this;
        }

        /**
         * Appends a descending {@code $sort} on the property referenced by {@code jlFunction}.
         */
        public AggregateWhere<E> desc(JLLambda.JLFunction<E, ?> jlFunction) {
            SortOperation desc = Aggregation.sort(Sort.by(Sort.Order.desc(JLLambda.getProperty(jlFunction))));
            aggregationOperations.add(desc);
            return this;
        }

        /**
         * Groups by the property referenced by {@code groupFunction} and stores each group's document
         * count in the property referenced by {@code countFunction}.
         *
         * <p>Every other property takes its value from the last document of the group.
         */
        public AggregateWhere<E> group(JLLambda.JLFunction<E, ?> groupFunction, JLLambda.JLFunction<E, ?> countFunction) {
            String countProperty = JLLambda.getProperty(countFunction);
            GroupOperation group = Aggregation.group(JLLambda.getProperty(groupFunction)).count().as(countProperty);
            Class<E> showClass = JLLambda.getClass(groupFunction);
            // The count property is already accumulated, so it must not also get a last() accumulator.
            return appendGroup(group, showClass, countProperty);
        }

        /**
         * Groups by the property referenced by {@code groupFunction}.
         *
         * <p>Every property takes its value from the last document of the group.
         */
        public AggregateWhere<E> group(JLLambda.JLFunction<E, ?> groupFunction) {
            String property = JLLambda.getProperty(groupFunction);
            GroupOperation group = Aggregation.group(property).last(property).as(property);
            Class<E> showClass = JLLambda.getClass(groupFunction);
            return appendGroup(group, showClass, property);
        }

        /**
         * Completes {@code group} and appends it, followed by a projection, to the pipeline.
         *
         * <p>For every property of {@code showClass} except {@code alreadyAccumulated}, it adds
         * {@code last(p) as p}. It then appends the group stage and a {@code $project} of all of
         * {@code showClass}'s properties.
         */
        private AggregateWhere<E> appendGroup(GroupOperation group, Class<E> showClass, String alreadyAccumulated) {
            List<JLTuple.Tuple3<String, Object, Class<?>>> properties = JLReflect.PropertyReflect.getProperty(showClass);
            List<String> projected = new ArrayList<>(properties.size());
            GroupOperation completed = group;
            for (JLTuple.Tuple3<String, Object, Class<?>> property : properties) {
                String name = property.getV1();
                if (!name.equals(alreadyAccumulated)) {
                    completed = completed.last(name).as(name);
                }
                projected.add(name);
            }
            aggregationOperations.add(completed);
            aggregationOperations.add(Aggregation.project(projected.toArray(new String[0])));
            return this;
        }

        /**
         * Finishes building and returns the query operations for the current pipeline.
         *
         * <p>The pipeline is snapshotted; stages added to this builder afterwards do not affect the
         * returned object.
         */
        public AggregateOps<E> select() {
            return new AggregateOps<>(aggregationOperations, mongoTemplate, t);
        }
    }

    /**
     * Terminal operations over a fixed aggregation pipeline.
     *
     * <p>Each operation builds its own copy of the pipeline, so operations can be called repeatedly
     * and in any order.
     *
     * @param <E> entity type
     */
    public static class AggregateOps<E> {
        /** Output field name of the {@code $count} stage used for counting. */
        private static final String COUNT_FIELD = "page_count_str";
        private final List<AggregationOperation> aggregationOperations;
        private final MongoTemplate mongoTemplate;
        private final Class<E> t;

        public AggregateOps(List<AggregationOperation> aggregationOperations, MongoTemplate mongoTemplate, Class<E> t) {
            // Defensive copy: stages appended by the caller afterwards must not leak into this pipeline.
            this.aggregationOperations = new ArrayList<>(aggregationOperations);
            this.mongoTemplate = mongoTemplate;
            this.t = t;
        }

        private Aggregation getAggregation(List<AggregationOperation> operations) {
            return Aggregation.newAggregation(operations);
        }

        /**
         * Returns a fresh pipeline: the base stages followed by {@code extraStages}.
         * The base list itself is never modified.
         */
        private List<AggregationOperation> pipelineWith(AggregationOperation... extraStages) {
            List<AggregationOperation> pipeline = new ArrayList<>(aggregationOperations.size() + extraStages.length);
            pipeline.addAll(aggregationOperations);
            for (AggregationOperation stage : extraStages) {
                pipeline.add(stage);
            }
            return pipeline;
        }

        /**
         * Runs the pipeline and returns the first result.
         *
         * @return the first result, or {@code null} when {@code JLEmpty.check} rejects the result list
         *         (presumably: no results)
         */
        public E getOne() {
            List<E> results = list();
            return JLEmpty.check(results) ? results.get(0) : null;
        }

        /**
         * Runs the pipeline.
         *
         * @return all results mapped to {@code E}
         */
        public List<E> list() {
            return mongoTemplate.aggregate(getAggregation(pipelineWith()), t, t).getMappedResults();
        }

        /**
         * Counts the documents produced by the pipeline.
         *
         * @return result count; 0 when the pipeline yields nothing
         */
        @SuppressWarnings("rawtypes")
        public long count() {
            Aggregation aggregation = getAggregation(pipelineWith(Aggregation.count().as(COUNT_FIELD)));
            List<Map> countList = mongoTemplate.aggregate(aggregation, t, Map.class).getMappedResults();
            // $count emits no document at all when its input is empty.
            if (CollectionUtils.isEmpty(countList)) {
                return 0;
            }
            return Long.parseLong(countList.get(0).get(COUNT_FIELD).toString());
        }

        /**
         * Returns one page of pipeline results.
         *
         * @param page zero-based page index; must be &gt;= 0
         * @param size page size; must be &gt;= 1
         * @return the requested page with the total result count
         * @throws IllegalArgumentException if {@code page} or {@code size} is out of range
         */
        public Page<E> page(int page, int size) {
            // Validates the arguments before any database round-trip.
            Pageable pageable = PageRequest.of(page, size);
            long total = count();
            // getOffset() is a long, so page * size cannot overflow int for deep pages.
            List<AggregationOperation> pipeline = pipelineWith(
                    Aggregation.skip(pageable.getOffset()),
                    Aggregation.limit(size));
            List<E> content = mongoTemplate.aggregate(getAggregation(pipeline), t, t).getMappedResults();
            return PageableExecutionUtils.getPage(content, pageable, () -> total);
        }
    }
}
